import numpy as np
from typing import List, Tuple, Dict

def object_generator(min_distance):
    """Build a randomly sized grid populated with decorative objects.

    The grid height and width are each drawn uniformly from 20..30
    (inclusive). Between 3 and 7 objects of at most 6x6 cells are
    requested from ``create_object_based_input``.

    Args:
        min_distance: Minimum spacing between objects, forwarded unchanged.

    Returns:
        The ``(grid, extra)`` pair returned by ``create_object_based_input``.
    """
    # Rows are drawn before columns so the random stream is consumed in a
    # fixed order.
    grid_rows = np.random.randint(20, 31)
    grid_cols = np.random.randint(20, 31)

    settings = {
        'rows': grid_rows,
        'cols': grid_cols,
        'min_objects': 3,
        'max_objects': 7,
        'max_height': 6,
        'max_width': 6,
        'max_colors': 9,
        'min_distance': min_distance,
    }
    return create_object_based_input(**settings)
    
def create_object_based_input(
    rows: int,
    cols: int,
    min_objects: int,
    max_objects: int,
    max_height: int,
    max_width: int,
    max_colors: int,
    min_distance: int = 1
) -> Tuple[np.ndarray, Dict]:
    """Fill a blank grid with random shapes and rank them by size.

    Objects are placed one at a time with ``try_place_object``: first until
    ``min_objects`` is reached, then until ``max_objects`` is reached. Either
    phase stops early as soon as a placement fails, so the result may hold
    fewer than ``min_objects`` objects.

    Args:
        rows: Grid height.
        cols: Grid width.
        min_objects: Target count of the first placement phase.
        max_objects: Target count of the second placement phase.
        max_height: Largest nominal object height.
        max_width: Largest nominal object width.
        max_colors: Color bound forwarded to ``try_place_object``.
        min_distance: Minimum spacing forwarded to ``try_place_object``.

    Returns:
        ``(grid, extra)`` where ``extra['objects']`` lists
        ``(top, left, height, width, color)`` tuples and
        ``extra['size_ranking'][i]`` is the 1-based rank of object ``i``
        (1 = most non-zero cells in its bounding box; ties are broken by
        smaller ``left``, then smaller ``top``).
    """
    grid = np.zeros((rows, cols), dtype=int)
    objects = []

    # Every shape maker shares the (height, width, color) -> ndarray signature.
    shape_makers = [
        create_flower,
        create_star,
        create_cross,
        create_heart,
        create_spiral,
        create_diamond,
        create_bullseye,
        create_smiley
    ]

    def place_one() -> bool:
        return try_place_object(
            grid, objects, shape_makers, max_height, max_width, max_colors, min_distance
        )

    # Phase 1 aims for the minimum, phase 2 for the maximum; a failed
    # placement ends the current phase only.
    for target in (min_objects, max_objects):
        while len(objects) < target and place_one():
            pass

    # Sort key per object: larger filled area first, then leftmost, then topmost.
    sort_keys = []
    for top, left, height, width, _color in objects:
        filled = np.sum(grid[top:top + height, left:left + width] != 0)
        sort_keys.append((-filled, left, top))
    order = sorted(range(len(sort_keys)), key=sort_keys.__getitem__)

    size_ranking = [0] * len(objects)
    for position, object_index in enumerate(order):
        size_ranking[object_index] = position + 1

    return grid, {'objects': objects, 'size_ranking': size_ranking}

def try_place_object(
    grid: np.ndarray,
    objects: List[Tuple[int, int, int, int, int]],
    object_functions: List,
    max_height: int,
    max_width: int,
    max_colors: int,
    min_distance: int = 1
) -> bool:
    """Try to draw one randomly chosen shape onto ``grid``.

    Up to 100 random rectangles (each side at least 3 cells) are tried. A
    rectangle is rejected when ``rect_distance`` to any already placed object
    is below ``min_distance`` or when the generated shape is empty. On
    success the shape is written into ``grid`` and its trimmed bounding box
    is appended to ``objects``; both arguments are modified in place.

    Args:
        grid: Target grid, modified in place.
        objects: Existing ``(top, left, height, width, color)`` records;
            the new record is appended on success.
        object_functions: Callables ``f(height, width, color) -> ndarray``.
        max_height: Largest nominal height to draw (capped by the grid).
        max_width: Largest nominal width to draw (capped by the grid).
        max_colors: Highest color value that may be drawn (inclusive).
        min_distance: Minimum allowed ``rect_distance`` to other objects.

    Returns:
        True if an object was placed, False after 100 failed attempts.
    """
    rows, cols = grid.shape
    max_attempts = 100

    for _ in range(max_attempts):
        height = np.random.randint(3, min(max_height, rows) + 1)
        width = np.random.randint(3, min(max_width, cols) + 1)
        top = np.random.randint(0, rows - height + 1)
        left = np.random.randint(0, cols - width + 1)

        # Reject candidates that come too close to an existing object.
        too_close = any(
            rect_distance(top, left, height, width, *obj[:4]) < min_distance
            for obj in objects
        )
        if too_close:
            continue

        # randint's upper bound is exclusive, so add 1 to make max_colors
        # itself reachable (and to keep max_colors=1 from raising).
        color = np.random.randint(1, max_colors + 1)
        object_function = np.random.choice(object_functions)
        object_grid = object_function(height, width, color)

        # Trim the shape to the bounding box of its non-zero cells.
        non_zero = np.nonzero(object_grid)
        if len(non_zero[0]) == 0:  # empty shape: try again
            continue
        actual_top = non_zero[0].min()
        actual_left = non_zero[1].min()
        actual_height = non_zero[0].max() - actual_top + 1
        actual_width = non_zero[1].max() - actual_left + 1

        adjusted_top = top + actual_top
        adjusted_left = left + actual_left

        # Safety net in case a shape function returns an array larger than
        # the requested (height, width).
        if adjusted_top + actual_height > rows or adjusted_left + actual_width > cols:
            continue

        trimmed = object_grid[actual_top:actual_top + actual_height,
                              actual_left:actual_left + actual_width]
        grid[adjusted_top:adjusted_top + actual_height,
             adjusted_left:adjusted_left + actual_width] = trimmed
        objects.append((adjusted_top, adjusted_left, actual_height, actual_width, color))
        return True

    return False

def rect_distance(top1: int, left1: int, height1: int, width1: int,
                  top2: int, left2: int, height2: int, width2: int) -> int:
    """Return the larger of the horizontal and vertical gaps between two rectangles.

    Each rectangle is given as ``(top, left, height, width)``. A gap is the
    number of empty cells separating the rectangles along one axis; it is 0
    when their projections on that axis touch or overlap.
    """
    # Gap along an axis = start of the later rectangle - end of the earlier one.
    column_gap = max(left1, left2) - min(left1 + width1, left2 + width2)
    row_gap = max(top1, top2) - min(top1 + height1, top2 + height2)

    # Negative values mean the projections overlap, so clamp at zero.
    return max(0, column_gap, row_gap)

def rect_overlap(top1, left1, height1, width1, top2, left2, height2, width2):
    """Tell whether two ``(top, left, height, width)`` rectangles share a cell.

    Rectangles that merely touch along an edge do not count as overlapping.
    """
    columns_intersect = left1 < left2 + width2 and left2 < left1 + width1
    rows_intersect = top1 < top2 + height2 and top2 < top1 + height1
    return columns_intersect and rows_intersect

def create_flower(height, width, color):
    """Draw a five-lobed flower on a ``height`` x ``width`` integer grid.

    A cell is colored when its distance from the grid centre does not exceed
    ``radius * (1 + 0.3 * sin(5 * angle))``, where ``radius`` is
    ``min(height, width) // 2 - 1``.

    Returns:
        The grid, holding ``color`` inside the flower and 0 elsewhere.
    """
    canvas = np.zeros((height, width), dtype=int)
    mid_row = height // 2
    mid_col = width // 2
    base_radius = min(height, width) // 2 - 1

    for row, col in np.ndindex(height, width):
        dy = row - mid_row
        dx = col - mid_col
        dist = np.sqrt(dy**2 + dx**2)
        angle = np.arctan2(dy, dx)
        # The sine term modulates the radius so it bulges five times per turn.
        petal_limit = base_radius * (1 + 0.3 * np.sin(5 * angle))
        if dist <= petal_limit:
            canvas[row, col] = color

    return canvas

def create_star(height, width, color):
    """Draw a five-armed, star-like shape on a ``height`` x ``width`` grid.

    The full turn is split into five equal sectors. In the first half of
    each sector cells are colored out to ``outer_radius``
    (``min(height, width) // 2 - 1``); in the second half only out to half
    of that.

    Returns:
        The grid, holding ``color`` inside the shape and 0 elsewhere.
    """
    canvas = np.zeros((height, width), dtype=int)
    mid_row, mid_col = height // 2, width // 2
    outer_radius = min(height, width) // 2 - 1

    for row, col in np.ndindex(height, width):
        dx, dy = col - mid_col, row - mid_row
        dist = np.sqrt(dx**2 + dy**2)
        angle = np.arctan2(dy, dx) % (2 * np.pi)
        # True in the second half of a sector, where the reach is halved.
        in_notch = angle % (2 * np.pi / 5) > np.pi / 5
        reach = outer_radius * (1 - 0.5 * in_notch)
        if dist <= reach:
            canvas[row, col] = color

    return canvas

def create_cross(height, width, color):
    """Draw a centred plus sign spanning the whole ``height`` x ``width`` grid.

    The nominal bar thickness is ``min(height, width) // 3``; each bar
    extends ``thickness // 2`` cells either side of the central row/column.

    Returns:
        The grid, holding ``color`` on both bars and 0 elsewhere.
    """
    canvas = np.zeros((height, width), dtype=int)
    half_bar = (min(height, width) // 3) // 2
    mid_row, mid_col = height // 2, width // 2

    horizontal_bar = slice(mid_row - half_bar, mid_row + half_bar + 1)
    vertical_bar = slice(mid_col - half_bar, mid_col + half_bar + 1)
    canvas[horizontal_bar, :] = color
    canvas[:, vertical_bar] = color

    return canvas

def create_heart(height, width, color):
    """Draw a heart-curve region on a ``height`` x ``width`` grid.

    Cell coordinates are normalised to roughly [-1, 1) around the grid
    centre, and a cell is colored when
    ``(nx^2 + ny^2 - 1)^3 - nx^2 * ny^3 <= 0``.

    Returns:
        The grid, holding ``color`` inside the region and 0 elsewhere.
    """
    canvas = np.zeros((height, width), dtype=int)
    half_h = height / 2
    half_w = width / 2

    for row, col in np.ndindex(height, width):
        nx = (col - half_w) / half_w
        ny = (row - half_h) / half_h
        implicit = (nx*nx + ny*ny - 1)**3 - nx*nx*ny*ny*ny
        if implicit <= 0:
            canvas[row, col] = color

    return canvas

def create_spiral(height, width, color):
    """Trace a spiral outward from the centre of a ``height`` x ``width`` grid.

    100 evenly spaced radii from 0 to ``min(height, width) // 2`` are
    sampled; the angle at each sample is ten times the radius. Each sample
    is truncated to a cell, and samples that land outside the grid are
    skipped.

    Returns:
        The grid, holding ``color`` on visited cells and 0 elsewhere.
    """
    canvas = np.zeros((height, width), dtype=int)
    mid_row, mid_col = height // 2, width // 2
    reach = min(height, width) // 2

    for radius in np.linspace(0, reach, num=100):
        angle = radius * 10
        col = int(mid_col + radius * np.cos(angle))
        row = int(mid_row + radius * np.sin(angle))
        inside = 0 <= col < width and 0 <= row < height
        if inside:
            canvas[row, col] = color

    return canvas

def create_diamond(height, width, color):
    """Draw a diamond centred on a ``height`` x ``width`` grid.

    A cell is colored when the sum of its row and column offsets from the
    centre is at most ``min(height, width) // 2``.

    Returns:
        The grid, holding ``color`` inside the diamond and 0 elsewhere.
    """
    canvas = np.zeros((height, width), dtype=int)
    reach = min(height, width) // 2

    # Column vector of row offsets plus row vector of column offsets
    # broadcasts to the per-cell offset sum.
    row_offsets = np.abs(np.arange(height) - height // 2)[:, np.newaxis]
    col_offsets = np.abs(np.arange(width) - width // 2)[np.newaxis, :]
    canvas[row_offsets + col_offsets <= reach] = color

    return canvas

def create_bullseye(height, width, color):
    """Draw alternating concentric rings over a ``height`` x ``width`` grid.

    Distances from the centre are divided into rings one third of
    ``min(height, width) // 2`` wide; even-numbered rings (counting the
    centre disc as ring 0) are colored, odd-numbered rings stay 0.

    Returns:
        The grid with the ring pattern.
    """
    canvas = np.zeros((height, width), dtype=int)
    mid_row, mid_col = height // 2, width // 2
    ring_width = (min(height, width) // 2) / 3

    for row, col in np.ndindex(height, width):
        dist = np.sqrt((row - mid_row)**2 + (col - mid_col)**2)
        ring_index = int(dist / ring_width)
        if ring_index % 2 == 0:
            canvas[row, col] = color

    return canvas

def create_smiley(height, width, color):
    """Draw a filled face disc with eyes and a smile on a grid.

    The eyes and the smile are painted in the same ``color`` as the face,
    so they are only visible where they extend beyond the face disc.

    Returns:
        A ``height`` x ``width`` integer grid holding ``color`` on painted
        cells and 0 elsewhere.
    """
    canvas = np.zeros((height, width), dtype=int)
    mid_row, mid_col = height // 2, width // 2
    face_radius = min(height, width) // 2 - 1

    # Face: every cell within face_radius of the centre (integer arithmetic).
    row_sq = (np.arange(height) - mid_row) ** 2
    col_sq = (np.arange(width) - mid_col) ** 2
    canvas[row_sq[:, np.newaxis] + col_sq[np.newaxis, :] <= face_radius ** 2] = color

    # Eyes: two small discs above the centre, mirrored left and right.
    eye_radius = face_radius // 4
    eye_offset = face_radius // 3
    eye_row = mid_row - eye_offset
    for eye_col in (mid_col - eye_offset, mid_col + eye_offset):
        first_row = max(0, eye_row - eye_radius)
        last_row = min(height, eye_row + eye_radius + 1)
        first_col = max(0, eye_col - eye_radius)
        last_col = min(width, eye_col + eye_radius + 1)
        for row in range(first_row, last_row):
            for col in range(first_col, last_col):
                if (col - eye_col)**2 + (row - eye_row)**2 <= eye_radius**2:
                    canvas[row, col] = color

    # Smile: a shallow arc below the centre, one cell per column.
    smile_half_width = face_radius // 2
    smile_base_row = mid_row + face_radius // 3
    for col in range(mid_col - smile_half_width, mid_col + smile_half_width + 1):
        sag = int(np.sqrt(max(0, smile_half_width**2 - (col - mid_col)**2)) // 2)
        row = smile_base_row + sag
        if 0 <= row < height:
            canvas[row, col] = color

    return canvas


import numpy as np
from typing import List, Tuple, Dict

def create_shape_with_holes(height: int, width: int, color: int) -> Tuple[np.ndarray, List[Tuple[int, int, int, int]]]:
    """Create a solid rectangle and carve random rectangular holes into it.

    Shapes smaller than 6 cells in either dimension are returned solid.
    Otherwise up to three holes are attempted; each gets at most 50 random
    tries and is dropped if no try avoids the holes already carved.

    Args:
        height: Shape height.
        width: Shape width.
        color: Fill value of the solid part.

    Returns:
        ``(shape, holes)`` where ``holes`` lists the carved rectangles as
        ``(top, left, height, width)`` in shape coordinates.
    """
    shape = np.full((height, width), color, dtype=int)
    holes = []

    big_enough = height >= 6 and width >= 6
    if not big_enough:
        return shape, holes

    hole_budget = min(3, (height - 2) // 2, (width - 2) // 2)
    if hole_budget < 1:
        return shape, holes

    tallest = min(height // 2, height - 3)
    widest = min(width // 2, width - 3)

    for _ in range(np.random.randint(1, hole_budget + 1)):
        attempts_left = 50  # bounded retries so the search always ends
        while attempts_left > 0:
            attempts_left -= 1
            hole_height = np.random.randint(2, tallest + 1)
            hole_width = np.random.randint(2, widest + 1)

            fits = hole_height < height - 2 and hole_width < width - 2
            if not fits:
                continue

            # Positions start at 1, so the first row and column stay solid.
            hole_top = np.random.randint(1, height - hole_height - 1)
            hole_left = np.random.randint(1, width - hole_width - 1)
            candidate = (hole_top, hole_left, hole_height, hole_width)

            clashes = any(rect_overlap(*candidate, *carved) for carved in holes)
            if not clashes:
                shape[hole_top:hole_top + hole_height, hole_left:hole_left + hole_width] = 0
                holes.append(candidate)
                break

    return shape, holes

def rect_overlap(top1, left1, height1, width1, top2, left2, height2, width2):
    """Report whether two ``(top, left, height, width)`` rectangles intersect.

    Edge-to-edge contact is not an intersection: at least one cell must be
    shared.
    """
    right1, right2 = left1 + width1, left2 + width2
    bottom1, bottom2 = top1 + height1, top2 + height2

    shares_columns = left1 < right2 and right1 > left2
    shares_rows = top1 < bottom2 and bottom1 > top2
    return shares_columns and shares_rows

def object_generator_with_holes(min_distance: int = 2):
    """Build a randomly sized grid of rectangles that may contain holes.

    The grid height and width are each drawn uniformly from 20..30
    (inclusive). Between 3 and 5 objects with sides of 6 to 12 cells are
    requested from ``create_object_based_input_with_holes``.

    Args:
        min_distance: Minimum spacing between objects, forwarded unchanged.

    Returns:
        The ``(grid, extra)`` pair returned by
        ``create_object_based_input_with_holes``.
    """
    grid_rows = np.random.randint(20, 31)
    grid_cols = np.random.randint(20, 31)

    size_limits = {
        'min_height': 6,
        'max_height': 12,
        'min_width': 6,
        'max_width': 12,
    }
    return create_object_based_input_with_holes(
        rows=grid_rows,
        cols=grid_cols,
        min_objects=3,
        max_objects=5,
        max_colors=9,
        min_distance=min_distance,
        **size_limits,
    )

import numpy as np
from typing import List, Tuple, Dict

class ObjectPlacementError(Exception):
    """Signal that fewer objects than the required minimum could be placed."""

def create_object_based_input_with_holes(
    rows: int,
    cols: int,
    min_objects: int,
    max_objects: int,
    min_height: int,
    max_height: int,
    min_width: int,
    max_width: int,
    max_colors: int,
    min_distance: int = 2
) -> Tuple[np.ndarray, Dict]:
    """Fill a blank grid with rectangles that may contain holes.

    Up to ``max_objects`` rectangles are placed, stopping at the first
    placement that fails after 100 random tries. A candidate is rejected
    when it overlaps an existing object or when the gap between the two
    rectangles (``rect_distance``) is smaller than ``min_distance``.

    Args:
        rows: Grid height.
        cols: Grid width.
        min_objects: Minimum number of objects that must be placed.
        max_objects: Maximum number of objects to place.
        min_height: Smallest object height.
        max_height: Largest object height (capped by ``rows``).
        min_width: Smallest object width.
        max_width: Largest object width (capped by ``cols``).
        max_colors: Highest color value that may be drawn (inclusive).
        min_distance: Minimum gap, in cells, between any two objects.

    Returns:
        ``(grid, extra)`` where ``extra['objects']`` lists
        ``(top, left, height, width, color)`` tuples and
        ``extra['object_holes']`` maps an object's index to its holes
        (only objects that have holes appear).

    Raises:
        ObjectPlacementError: If fewer than ``min_objects`` were placed.
    """
    grid = np.zeros((rows, cols), dtype=int)
    objects = []
    object_holes = {}

    def too_close(top, left, height, width) -> bool:
        # Measure the real gap between rectangles. Comparing only the
        # top-left corners would reject distant objects and accept
        # rectangles that touch. The overlap test is kept so that
        # min_distance <= 0 still forbids overlapping objects.
        return any(
            rect_overlap(top, left, height, width, *obj[:4])
            or rect_distance(top, left, height, width, *obj[:4]) < min_distance
            for obj in objects
        )

    def try_place_object():
        for _ in range(100):  # Max attempts
            height = np.random.randint(min_height, min(max_height, rows) + 1)
            width = np.random.randint(min_width, min(max_width, cols) + 1)
            top = np.random.randint(0, rows - height + 1)
            left = np.random.randint(0, cols - width + 1)

            if too_close(top, left, height, width):
                continue

            color = np.random.randint(1, max_colors + 1)
            shape, holes = create_shape_with_holes(height, width, color)

            grid[top:top+height, left:left+width] = shape
            objects.append((top, left, height, width, color))
            if holes:  # Only record objects that actually have holes
                object_holes[len(objects) - 1] = holes
            return True
        return False

    for _ in range(max_objects):  # Attempt to place up to max_objects
        if not try_place_object():
            break

    if len(objects) < min_objects:
        raise ObjectPlacementError(f"Failed to place minimum number of objects. Placed {len(objects)}, minimum required: {min_objects}")

    extra = {
        'objects': objects,
        'object_holes': object_holes
    }

    return grid, extra

# create_shape_with_holes and rect_overlap, used above, are defined earlier in this file.
import numpy as np
from typing import List, Tuple, Dict
import numpy as np
from typing import List, Tuple, Dict

def create_shape_with_large_holes(min_size: int, max_size: int, color: int) -> Tuple[np.ndarray, List[Tuple[int, int, int, int]]]:
    """Create a random square-canvas shape and carve square holes into it.

    The canvas side is drawn from ``min_size``..``max_size`` (inclusive) and
    the outline is one of diamond, circle, star or square. One to three
    holes are then attempted. A hole is carved only where it and a one-cell
    margin around it are entirely ``color``, so holes never touch the
    outline or each other. Holes that cannot fit inside the one-cell border
    of a very small canvas are skipped.

    Args:
        min_size: Smallest canvas side.
        max_size: Largest canvas side.
        color: Fill value of the shape.

    Returns:
        ``(shape, holes)`` where ``holes`` lists the carved squares as
        ``(top, left, size, size)`` in shape coordinates.
    """
    size = np.random.randint(min_size, max_size + 1)
    shape = np.zeros((size, size), dtype=int)
    center = size // 2

    # Create the outer shape
    shape_type = np.random.choice(['diamond', 'circle', 'star', 'square'])

    if shape_type == 'diamond':
        offsets = np.abs(np.arange(size) - center)
        shape[offsets[:, np.newaxis] + offsets[np.newaxis, :] <= center] = color

    elif shape_type == 'circle':
        squared = (np.arange(size) - center) ** 2
        shape[squared[:, np.newaxis] + squared[np.newaxis, :] <= center ** 2] = color

    elif shape_type == 'star':
        for y in range(size):
            for x in range(size):
                angle = np.arctan2(y - center, x - center)
                distance = np.sqrt((x - center)**2 + (y - center)**2)
                # Radius oscillates between 0 and center, five times per turn.
                if distance <= center * (0.5 + 0.5 * np.cos(angle * 5)):
                    shape[y, x] = color

    elif shape_type == 'square':
        shape.fill(color)

    # Create holes
    holes = []
    max_hole_size = max(size // 2 - 1, 2)  # Allow for larger holes
    num_holes = np.random.randint(1, 4)  # 1 to 3 holes
    for _ in range(num_holes):
        hole_size = np.random.randint(1, max_hole_size + 1)

        # Exclusive upper bound for the hole's top-left corner. On tiny
        # canvases no position exists and randint(1, limit) would raise
        # ValueError, so skip the hole instead.
        position_limit = size - hole_size - 1
        if position_limit <= 1:
            continue

        for _attempt in range(50):  # Bounded retries
            hole_top = np.random.randint(1, position_limit)
            hole_left = np.random.randint(1, position_limit)

            # The hole plus a one-cell margin must lie entirely inside the
            # colored area; the position bounds keep this window in range.
            window = shape[hole_top - 1:hole_top + hole_size + 1,
                           hole_left - 1:hole_left + hole_size + 1]
            if np.all(window == color):
                shape[hole_top:hole_top + hole_size, hole_left:hole_left + hole_size] = 0
                holes.append((hole_top, hole_left, hole_size, hole_size))
                break

    return shape, holes

def place_shapes_until_full(rows: int, cols: int, max_colors: int, min_shape_size: int, max_shape_size: int) -> Tuple[np.ndarray, Dict]:
    """Keep dropping random shapes onto a grid until no more seem to fit.

    A fresh shape is generated on every iteration and given 50 random
    positions; it is placed at the first position whose footprint is
    entirely empty. The loop ends after ``rows * cols`` consecutive shapes
    could not be placed. A shape larger than the grid counts as a failed
    placement.

    Args:
        rows: Grid height.
        cols: Grid width.
        max_colors: Highest color value that may be drawn (inclusive).
        min_shape_size: Smallest shape side.
        max_shape_size: Largest shape side.

    Returns:
        ``(grid, extra)`` where ``extra['objects']`` lists
        ``(top, left, height, width, color)`` tuples and
        ``extra['object_holes']`` maps an object's index to its holes
        (only objects that have holes appear).
    """
    grid = np.zeros((rows, cols), dtype=int)
    objects = []
    object_holes = {}
    attempts = 0
    max_attempts = rows * cols  # Upper bound on consecutive failures

    while attempts < max_attempts:
        color = np.random.randint(1, max_colors + 1)
        shape, holes = create_shape_with_large_holes(min_shape_size, max_shape_size, color)
        shape_height, shape_width = shape.shape

        placed = False
        # An oversized shape can never fit; treat it as a failure rather
        # than resetting the failure counter, which could loop forever.
        if shape_height <= rows and shape_width <= cols:
            for _ in range(50):  # 50 positions per shape
                top = np.random.randint(0, rows - shape_height + 1)
                left = np.random.randint(0, cols - shape_width + 1)

                if np.all(grid[top:top+shape_height, left:left+shape_width] == 0):
                    grid[top:top+shape_height, left:left+shape_width] = shape
                    obj_index = len(objects)
                    objects.append((top, left, shape_height, shape_width, shape.max()))
                    if holes:
                        object_holes[obj_index] = holes
                    placed = True
                    break

        # Only a real placement resets the consecutive-failure counter.
        attempts = 0 if placed else attempts + 1

    extra = {
        'objects': objects,
        'object_holes': object_holes
    }
    return grid, extra

def diverse_shapes_generator():
    """Build a randomly sized grid packed with holed shapes.

    The grid height and width are each drawn uniformly from 20..30
    (inclusive). Shapes have sides of 5 to 15 cells and colors 1 to 5.

    Returns:
        The ``(grid, extra)`` pair returned by ``place_shapes_until_full``.
    """
    grid_rows = np.random.randint(20, 31)
    grid_cols = np.random.randint(20, 31)
    return place_shapes_until_full(
        grid_rows,
        grid_cols,
        max_colors=5,
        min_shape_size=5,
        max_shape_size=15,
    )

# Example usage:
# grid, extra = diverse_shapes_generator()



def generate_half_objects_input(max_size: int = 30, min_objects: int = 2, max_objects: int = 5) -> Tuple[np.ndarray, Dict]:
    """Generate a grid showing only the top half of each random object.

    Each object is a random pattern (about 70% colored cells) of 4 to 8
    cells per side. Its full bounding box is recorded, but only the rows
    above the horizontal mirror axis are drawn on the grid. Objects for
    which no free spot is found in 100 tries are skipped.

    Args:
        max_size: Largest grid side (the smallest is 20).
        min_objects: Smallest number of objects to attempt.
        max_objects: Largest number of objects to attempt.

    Returns:
        ``(grid, extra)`` where ``extra`` holds ``'objects'`` (full boxes),
        ``'half_objects'`` (drawn boxes), both as
        ``(top, left, height, width, color)`` tuples, and ``'mirror_axes'``
        as ``('horizontal', row)`` pairs.
    """
    rows = np.random.randint(20, max_size + 1)
    cols = np.random.randint(20, max_size + 1)
    grid = np.zeros((rows, cols), dtype=int)

    requested = np.random.randint(min_objects, max_objects + 1)
    objects, half_objects, mirror_axes = [], [], []

    for _ in range(requested):
        color = np.random.randint(1, 10)
        height = np.random.randint(4, 9)
        width = np.random.randint(4, 9)

        # Look for a spot whose surroundings, widened by the object's own
        # dimensions, are still empty.
        for _attempt in range(100):
            top = np.random.randint(0, rows - height)
            left = np.random.randint(0, cols - width)
            surroundings = grid[max(0, top - width):min(rows, top + height + width),
                                max(0, left - height):min(cols, left + width + height)]
            if np.all(surroundings == 0):
                break
        else:
            continue  # no free spot found: skip this object

        object_grid = np.random.choice([0, color], size=(height, width), p=[0.3, 0.7])
        objects.append((top, left, height, width, color))

        # This variant always mirrors about a horizontal axis.
        half_height = height // 2
        half_objects.append((top, left, half_height, width, color))
        mirror_axes.append(('horizontal', top + half_height))

        # The target area was just verified empty, so drawing only the kept
        # half leaves the rest of the bounding box at zero.
        grid[top:top + half_height, left:left + width] = object_grid[:half_height, :]

    return grid, {
        'objects': objects,
        'half_objects': half_objects,
        'mirror_axes': mirror_axes
    }


def generate_half_objects_input_vert(max_size: int = 30, min_objects: int = 2, max_objects: int = 5) -> Tuple[np.ndarray, Dict]:
    """Generate a grid showing only the left half of each random object.

    Each object is a random pattern (about 70% colored cells) of 4 to 8
    cells per side. Its full bounding box is recorded, but only the columns
    left of the vertical mirror axis are drawn on the grid. Objects for
    which no free spot is found in 100 tries are skipped.

    Args:
        max_size: Largest grid side (the smallest is 20).
        min_objects: Smallest number of objects to attempt.
        max_objects: Largest number of objects to attempt.

    Returns:
        ``(grid, extra)`` where ``extra`` holds ``'objects'`` (full boxes),
        ``'half_objects'`` (drawn boxes), both as
        ``(top, left, height, width, color)`` tuples, and ``'mirror_axes'``
        as ``('vertical', column)`` pairs.
    """
    rows = np.random.randint(20, max_size + 1)
    cols = np.random.randint(20, max_size + 1)
    grid = np.zeros((rows, cols), dtype=int)

    requested = np.random.randint(min_objects, max_objects + 1)
    objects, half_objects, mirror_axes = [], [], []

    for _ in range(requested):
        color = np.random.randint(1, 10)
        height = np.random.randint(4, 9)
        width = np.random.randint(4, 9)

        # Look for a spot whose surroundings, widened by the object's own
        # dimensions, are still empty.
        for _attempt in range(100):
            top = np.random.randint(0, rows - height)
            left = np.random.randint(0, cols - width)
            surroundings = grid[max(0, top - width):min(rows, top + height + width),
                                max(0, left - height):min(cols, left + width + height)]
            if np.all(surroundings == 0):
                break
        else:
            continue  # no free spot found: skip this object

        object_grid = np.random.choice([0, color], size=(height, width), p=[0.3, 0.7])
        objects.append((top, left, height, width, color))

        # This variant always mirrors about a vertical axis.
        half_width = width // 2
        half_objects.append((top, left, height, half_width, color))
        mirror_axes.append(('vertical', left + half_width))

        # The target area was just verified empty, so drawing only the kept
        # half leaves the rest of the bounding box at zero.
        grid[top:top + height, left:left + half_width] = object_grid[:, :half_width]

    return grid, {
        'objects': objects,
        'half_objects': half_objects,
        'mirror_axes': mirror_axes
    }



import numpy as np
from typing import Tuple, Dict
from functools import partial


def noisy_object_generator(max_size: int = 30, min_objects: int = 2, max_objects: int = 5, noise_density: float = 0.1) -> Tuple[np.ndarray, Dict]:
    """Generate an object grid and a copy sprinkled with isolated noise pixels.

    The clean grid comes from ``object_generator(min_distance=2)``. Noise
    pixels get a random color in 1..9 and are only placed on cells that are
    neither on nor adjacent (including diagonally) to an object cell or a
    previously placed noise pixel.

    Args:
        max_size: Not used by this function.
        min_objects: Not used by this function.
        max_objects: Not used by this function.
        noise_density: Fraction of the grid area used as the upper bound on
            the number of noise pixels.

    Returns:
        ``(noisy_grid, extras)`` where ``extras['original_grid']`` is the
        clean grid and ``extras['original_extras']`` is the metadata
        returned by ``object_generator``.
    """
    original_grid, original_extras = object_generator(min_distance=2)

    rows, cols = original_grid.shape
    noisy_grid = original_grid.copy()

    # Objects that are a single cell thick are padded to at least 2 cells
    # in that dimension, on both the clean and the noisy grid.
    for top, left, height, width, color in original_extras['objects']:
        if height == 1 or width == 1:
            patch = (slice(top, top + max(2, height)), slice(left, left + max(2, width)))
            original_grid[patch] = color
            noisy_grid[patch] = color

    potential_noise_mask = np.ones_like(original_grid, dtype=bool)

    def block_neighbourhood(i, j):
        # Slices clip at the far edges on their own; only the near edge
        # needs clamping to keep the start index from going negative.
        potential_noise_mask[max(0, i - 1):i + 2, max(0, j - 1):j + 2] = False

    # Object cells and their 8-neighbourhoods are off limits for noise.
    for i, j in zip(*np.nonzero(original_grid)):
        block_neighbourhood(i, j)

    for _ in range(int(noise_density * rows * cols)):
        free_rows, free_cols = np.where(potential_noise_mask)
        if len(free_rows) == 0:
            break

        pick = np.random.randint(len(free_rows))
        i, j = free_rows[pick], free_cols[pick]

        noisy_grid[i, j] = np.random.randint(1, 10)  # Random color for noise

        # Keep later noise pixels from touching this one.
        block_neighbourhood(i, j)

    return noisy_grid, {
        'original_grid': original_grid,
        'original_extras': original_extras
    }